[PyTorch][Fused Attn] Add support for cuDNN to return Softmax Stats always and Max when `return_max_logit=True` (#2677)
Conversation
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Greptile Summary: This PR updates the cuDNN-based fused attention forward pass to always return `Stats`, and to return `Max` when `return_max_logit=True`. Key changes:
Two SM120 (Blackwell) regressions were found in
Confidence Score: 2/5
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[fused_attn_arbitrary_seqlen_fwd] --> B{Aux_CTX_Tensors empty?}
B -->|Yes - allocate| C[Allocate Stats tensor\nalways, index 0]
C --> D{return_max_logit?}
D -->|Yes| E[Allocate Max tensor\nindex 1]
D -->|No| F[Skip Max]
E --> G[Allocate rng_state\nindex 2]
F --> G2[Allocate rng_state\nindex 1]
B -->|No - use existing| H[Read Stats → devPtrS1\nindex 0]
H --> I{return_max_logit?}
I -->|Yes| J[Read Max → devPtrS2\nindex 1]
I -->|No| K[Skip Max]
J --> L[Read rng_state\nindex 2]
K --> L2[Read rng_state\nindex 1]
subgraph graph_builder [Graph Builder - fused_attn_arbitrary_seqlen_fwd_impl]
M[sdpa generates O, Stats always] --> N[Stats: set_output=true\nset stride always]
N --> O{return_max_logit?}
O -->|Yes| P[Max tensor\nset_logit_max]
O -->|No| Q[Stats only\nStats_tuple = Stats, null]
P --> R[Stats_tuple = Stats, Max]
end
subgraph python [Python fused_attn_fwd - return_max_logit=True]
S[output_tensors: out, Stats, Max, rng_state, ...]
S --> T[aux_ctx_tensors = Stats + rng_state + optional]
S --> U[max_tensor = output_tensors 2 = Max]
U --> V[max_logit = amax over batch/seq dims]
T --> W[return out, aux_ctx_tensors, max_logit]
end
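The allocation/indexing branches in the flowchart can be sketched in plain Python (hypothetical helper name, list-based stand-ins for the tensors — not the actual TE code): `Stats` always occupies index 0, and `Max`, when requested, takes index 1 and pushes `rng_state` from index 1 to index 2.

```python
def pack_aux_ctx_tensors(stats, rng_state, max_tensor=None):
    # Index layout from the flowchart:
    #   index 0: Stats (always allocated)
    #   index 1: Max, only when return_max_logit=True
    #   next index: rng_state (1 without Max, 2 with Max)
    aux = [stats]                  # Stats always at index 0
    if max_tensor is not None:     # return_max_logit=True path
        aux.append(max_tensor)     # Max at index 1
    aux.append(rng_state)          # rng_state at index 1 or 2
    return aux

with_max = pack_aux_ctx_tensors("Stats", "rng_state", "Max")
without_max = pack_aux_ctx_tensors("Stats", "rng_state")
```

The backward pass can then locate `rng_state` purely from `return_max_logit`, without probing tensor shapes.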
Additional Comments (1)
The public docstring still describes
stats = output_tensors[1] + torch.log(output_tensors[2])
# thd: output_tensors: out [tq, h, d], Stats [tq, h, 1], Max [tq, h, 1]
# bshd: output_tensors: out [b, sq, h, d], Stats [b, h, sq, 1], Max [b, h, sq, 1]
# sbhd: output_tensors: out [sq, b, h, d], Stats [b, h, sq, 1], Max [b, h, sq, 1] (there's no typo here)
Do we need the "there's no typo here" :)
I deliberately added it because I didn't believe it and checked the shapes myself :P
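Given the layouts in the excerpt above (e.g. `Max` as `[b, h, sq, 1]` in bshd), the `max_logit` reduction from the flowchart — an amax over the batch and sequence dims, leaving one value per head — can be sketched with nested lists standing in for tensors (illustrative only, not the TE implementation):

```python
def max_logit_per_head(max_tensor):
    # max_tensor: nested lists shaped [b, h, sq, 1]; reduce over the
    # batch (b) and sequence (sq) dims to one max logit per head.
    b = len(max_tensor)
    h = len(max_tensor[0])
    return [
        max(max_tensor[bi][hi][si][0]
            for bi in range(b)
            for si in range(len(max_tensor[bi][hi])))
        for hi in range(h)
    ]

# toy tensor with b=2, h=2, sq=2
mx = [
    [[[0.1], [0.7]], [[2.0], [-1.0]]],  # batch 0
    [[[0.3], [0.2]], [[1.5], [4.0]]],   # batch 1
]
per_head = max_logit_per_head(mx)
```

In the real wrapper this would be a single `amax` call over the batch/sequence dims of the `Max` tensor.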
transformer_engine/common/include/transformer_engine/fused_attn.h
/te-ci L2
… always and `Max` when `return_max_logit=True` (#2677)

* cudnn now returns Stats always and Max only with `return_max_logit=true`
* fix a typo that caused a bug
* update doc strings
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* fix more docs
* fixes from the feedback
* update cudnn-frontend to v1.19.1
* update the cudnn frontend
* fix a wrong omission
* [pre-commit.ci] auto fixes from pre-commit.com hooks

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Description
cuDNN recently made it possible to return any subset of {`Stats`, `SumExp`, `Max`}. This PR adapts TE to always get `Stats` from cuDNN, and the `Max` tensor if `return_max_logit=True`. (Note that `Stats = log(SumExp) + Max`.)

Type of change
Changes
Please list the changes introduced in this PR:
* `fused_attn_f16_arbitrary_seqlen.cu`: removed the `SumExp` tensor, as it's not needed since cuDNN returns `Stats` by default; set `generate_stats=True`, which forces cuDNN to always return the `Stats` tensor (needed in the backward pass)
* `transformer_engine/pytorch/cpp_extensions/fused_attn.py`: removed the `Stats = log(SumExp) + Max` computation, since cuDNN returns `Stats` directly and TE doesn't need `SumExp` from cuDNN

Checklist:
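The `Stats = log(SumExp) + Max` relation referenced throughout this PR is just the numerically stable log-sum-exp decomposition used by softmax. A quick plain-Python check with arbitrary illustrative values:

```python
import math

logits = [0.5, -1.2, 3.0, 0.0]  # one attention row, toy values

mx = max(logits)                                 # Max
sum_exp = sum(math.exp(x - mx) for x in logits)  # SumExp (max-shifted)
stats = math.log(sum_exp) + mx                   # Stats = log(SumExp) + Max

# Stats equals the naive log(sum(exp(x))), up to float rounding,
# which is why any one of the three can be derived from the other two.
naive = math.log(sum(math.exp(x) for x in logits))
```

This is why cuDNN returning `Stats` directly makes the `SumExp` tensor (and the `log(SumExp) + Max` reconstruction on the Python side) unnecessary.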